import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fsrl.utils import DummyLogger, WandbLogger
from tqdm.auto import trange  # noqa
from typing import Dict, List, Union, Tuple, Optional, Callable
from osrl.common.net import mlp, EnsembleLinear, StandardScaler
from torch.distributions.normal import Normal
import os
from tqdm import tqdm


class Swish(nn.Module):
    """Swish activation: multiplies the input element-wise by its sigmoid."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(x) * x`` computed element-wise."""
        gate = torch.sigmoid(x)
        return gate * x


def soft_clamp(
    x : torch.Tensor,
    _min: Optional[torch.Tensor] = None,
    _max: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Softly bound ``x`` with softplus so gradients are maintained.

    The upper bound (if given) is applied first, then the lower bound (if
    given). With neither bound supplied, ``x`` is returned unchanged.
    """
    clamped = x
    has_upper = _max is not None
    has_lower = _min is not None
    if has_upper:
        # Smooth version of min(x, _max).
        clamped = _max - F.softplus(_max - clamped)
    if has_lower:
        # Smooth version of max(x, _min).
        clamped = _min + F.softplus(clamped - _min)
    return clamped


class EnsembleCostModel(nn.Module):
    """Ensemble of MLPs that maps a concatenated (obs, action) input to a cost probability.

    The network is a stack of ``EnsembleLinear`` layers, each followed by the
    chosen activation, and a single-unit ``EnsembleLinear`` head whose output
    is passed through a sigmoid.

    Args:
        obs_dim: Observation dimensionality.
        action_dim: Action dimensionality.
        hidden_dims: Widths of the hidden layers.
        num_ensemble: Number of ensemble members.
        num_elites: Number of members treated as elites.
        activation: Activation class; it is instantiated with no arguments.
        weight_decays: One weight-decay coefficient per layer (hidden layers
            plus the output layer). ``None`` means 0.0 for every layer.
        device: Device the module is moved to.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_dims: Union[List[int], Tuple[int]],
        num_ensemble: int = 7,
        num_elites: int = 5,
        activation: nn.Module = Swish,
        weight_decays: Optional[Union[List[float], Tuple[float]]] = None,
        device: str = "cpu"
    ) -> None:
        super().__init__()

        self.num_ensemble = num_ensemble
        self.num_elites = num_elites
        self.device = torch.device(device)

        self.activation = activation()

        # Resolve the default BEFORE validating its length; otherwise the
        # documented default (None) fails on len(None).
        if weight_decays is None:
            weight_decays = [0.0] * (len(hidden_dims) + 1)
        assert len(weight_decays) == (len(hidden_dims) + 1), \
            "weight_decays needs one entry per hidden layer plus one for the output layer"

        # Prepend the input width so consecutive pairs give (in_dim, out_dim).
        layer_dims = [obs_dim + action_dim] + list(hidden_dims)
        module_list = [
            EnsembleLinear(in_dim, out_dim, num_ensemble, weight_decay)
            for in_dim, out_dim, weight_decay in zip(layer_dims[:-1], layer_dims[1:], weight_decays[:-1])
        ]
        self.backbones = nn.ModuleList(module_list)

        self.output_layer = EnsembleLinear(
            layer_dims[-1],
            1,
            num_ensemble,
            weight_decays[-1]
        )

        # Elite indices are registered as a non-trainable parameter, so they
        # are part of the module's state_dict. Initially the first num_elites.
        self.register_parameter(
            "elites",
            nn.Parameter(torch.tensor(list(range(0, self.num_elites))), requires_grad=False)
        )

        self.to(self.device)

    def forward(self, obs_action: np.ndarray) -> torch.Tensor:
        """Return the sigmoid cost probability for ``obs_action``.

        The input is converted to a float32 tensor on ``self.device``.
        Presumably the output carries a leading ensemble axis and a trailing
        axis of size 1 — TODO confirm against ``EnsembleLinear``.
        """
        output = torch.as_tensor(obs_action, dtype=torch.float32).to(self.device)
        for layer in self.backbones:
            output = self.activation(layer(output))
        cost_prob = torch.sigmoid(self.output_layer(output))
        return cost_prob

    def load_save(self) -> None:
        """Call ``load_save`` on every layer (backbone layers, then the output layer)."""
        for layer in self.backbones:
            layer.load_save()
        self.output_layer.load_save()

    def update_save(self, indexes: List[int]) -> None:
        """Call ``update_save(indexes)`` on every layer.

        ``indexes`` presumably selects the ensemble members whose saved
        weights are refreshed — verify against ``EnsembleLinear``.
        """
        for layer in self.backbones:
            layer.update_save(indexes)
        self.output_layer.update_save(indexes)

    def get_decay_loss(self) -> torch.Tensor:
        """Return the sum of ``get_decay_loss()`` over all layers."""
        decay_loss = 0
        for layer in self.backbones:
            decay_loss += layer.get_decay_loss()
        decay_loss += self.output_layer.get_decay_loss()
        return decay_loss

    def set_elites(self, indexes: List[int]) -> None:
        """Replace the stored elite indices with ``indexes``."""
        assert len(indexes) <= self.num_ensemble and max(indexes) < self.num_ensemble
        self.register_parameter('elites', nn.Parameter(torch.tensor(indexes), requires_grad=False))

    def random_elite_idxs(self, batch_size: int) -> np.ndarray:
        """Draw ``batch_size`` elite indices uniformly, with replacement."""
        idxs = np.random.choice(self.elites.data.cpu().numpy(), size=batch_size)
        return idxs

class EnsembleDynamicsModel(nn.Module):
    """Ensemble of probabilistic MLPs predicting a Gaussian over the dynamics target.

    The main head outputs a mean and a log-variance for ``obs_dim`` values
    (plus one more when ``with_reward`` is true). When ``with_cost`` is true,
    an additional single-unit head with a sigmoid predicts a cost probability
    from the same backbone features.

    Args:
        obs_dim: Observation dimensionality.
        action_dim: Action dimensionality.
        hidden_dims: Widths of the hidden layers.
        num_ensemble: Number of ensemble members.
        num_elites: Number of members treated as elites.
        activation: Activation class; it is instantiated with no arguments.
        weight_decays: One weight-decay coefficient per layer (hidden layers
            plus the output layer). ``None`` means 0.0 for every layer. The
            cost head reuses the output layer's coefficient.
        with_reward: Whether the main head also predicts a reward.
        with_cost: Whether to build the extra cost-probability head.
        device: Device the module is moved to.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_dims: Union[List[int], Tuple[int]],
        num_ensemble: int = 7,
        num_elites: int = 5,
        activation: nn.Module = Swish,
        weight_decays: Optional[Union[List[float], Tuple[float]]] = None,
        with_reward: bool = True,
        with_cost: bool = True,
        device: str = "cpu"
    ) -> None:
        super().__init__()

        self.num_ensemble = num_ensemble
        self.num_elites = num_elites
        self._with_reward = with_reward
        self._with_cost = with_cost
        self.device = torch.device(device)

        self.activation = activation()

        # Resolve the default BEFORE validating its length; otherwise the
        # documented default (None) fails on len(None).
        if weight_decays is None:
            weight_decays = [0.0] * (len(hidden_dims) + 1)
        assert len(weight_decays) == (len(hidden_dims) + 1), \
            "weight_decays needs one entry per hidden layer plus one for the output layer"

        # Prepend the input width so consecutive pairs give (in_dim, out_dim).
        layer_dims = [obs_dim + action_dim] + list(hidden_dims)
        module_list = [
            EnsembleLinear(in_dim, out_dim, num_ensemble, weight_decay)
            for in_dim, out_dim, weight_decay in zip(layer_dims[:-1], layer_dims[1:], weight_decays[:-1])
        ]
        self.backbones = nn.ModuleList(module_list)

        # Number of predicted quantities: obs_dim, plus 1 if a reward is predicted.
        target_dim = obs_dim + self._with_reward

        # The head emits mean and log-variance side by side, hence 2 * target_dim.
        self.output_layer = EnsembleLinear(
            layer_dims[-1],
            2 * target_dim,
            num_ensemble,
            weight_decays[-1]
        )
        if self._with_cost:
            self.cost_output_layer = EnsembleLinear(
                layer_dims[-1],
                1,
                num_ensemble,
                weight_decays[-1]
            )

        # Trainable soft bounds used to clamp the predicted log-variance.
        self.register_parameter(
            "max_logvar",
            nn.Parameter(torch.ones(target_dim) * 0.5, requires_grad=True)
        )
        self.register_parameter(
            "min_logvar",
            nn.Parameter(torch.ones(target_dim) * -10, requires_grad=True)
        )

        # Elite indices are registered as a non-trainable parameter, so they
        # are part of the module's state_dict. Initially the first num_elites.
        self.register_parameter(
            "elites",
            nn.Parameter(torch.tensor(list(range(0, self.num_elites))), requires_grad=False)
        )

        self.to(self.device)

    def forward(
        self, obs_action: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(mean, logvar, cost_prob)`` for ``obs_action``.

        ``logvar`` is soft-clamped between ``min_logvar`` and ``max_logvar``.
        ``cost_prob`` is the sigmoid of the cost head, or ``None`` when the
        model was built with ``with_cost=False``.
        """
        output = torch.as_tensor(obs_action, dtype=torch.float32).to(self.device)
        for layer in self.backbones:
            output = self.activation(layer(output))
        # First half of the last axis is the mean, second half the log-variance.
        mean, logvar = torch.chunk(self.output_layer(output), 2, dim=-1)
        logvar = soft_clamp(logvar, self.min_logvar, self.max_logvar)
        if self._with_cost:
            cost_prob = torch.sigmoid(self.cost_output_layer(output))
            return mean, logvar, cost_prob
        return mean, logvar, None

    def load_save(self) -> None:
        """Call ``load_save`` on every layer, including the cost head if present."""
        for layer in self.backbones:
            layer.load_save()
        self.output_layer.load_save()
        if self._with_cost:
            self.cost_output_layer.load_save()

    def update_save(self, indexes: List[int]) -> None:
        """Call ``update_save(indexes)`` on every layer, including the cost head if present.

        ``indexes`` presumably selects the ensemble members whose saved
        weights are refreshed — verify against ``EnsembleLinear``.
        """
        for layer in self.backbones:
            layer.update_save(indexes)
        self.output_layer.update_save(indexes)
        if self._with_cost:
            self.cost_output_layer.update_save(indexes)

    def get_decay_loss(self) -> torch.Tensor:
        """Return the sum of ``get_decay_loss()`` over all layers."""
        decay_loss = 0
        for layer in self.backbones:
            decay_loss += layer.get_decay_loss()
        decay_loss += self.output_layer.get_decay_loss()
        if self._with_cost:
            decay_loss += self.cost_output_layer.get_decay_loss()
        return decay_loss

    def set_elites(self, indexes: List[int]) -> None:
        """Replace the stored elite indices with ``indexes``."""
        assert len(indexes) <= self.num_ensemble and max(indexes) < self.num_ensemble
        self.register_parameter('elites', nn.Parameter(torch.tensor(indexes), requires_grad=False))

    def random_elite_idxs(self, batch_size: int) -> np.ndarray:
        """Draw ``batch_size`` elite indices uniformly, with replacement."""
        idxs = np.random.choice(self.elites.data.cpu().numpy(), size=batch_size)
        return idxs


class EnsembleDynamics(object):
    """Trains an ensemble dynamics model (and optionally a separate cost model) and rolls it out.

    Args:
        model: Ensemble dynamics model; calling it returns ``(mean, logvar, cost_prob)``.
        cost_model: Optional separate ensemble cost model. When not ``None``
            it is trained by :meth:`train` and queried by :meth:`step`.
        optim: Optimizer for ``model``.
        cost_optim: Optimizer for ``cost_model``.
        scaler: Input scaler; it is fitted on the training inputs in :meth:`train`.
        terminal_fn: Called as ``terminal_fn(obs, action, next_obs)`` to get terminal flags.
        use_scheduler: Whether to step the LR schedulers after every optimizer step.
        dynamics_scheduler: LR scheduler for ``optim`` (may be ``None``).
        cost_model_scheduler: LR scheduler for ``cost_optim`` (may be ``None``).
        penalty_coef: Weight of the uncertainty penalty; a falsy value disables it.
        uncertainty_mode: One of ``"aleatoric"``, ``"pairwise-diff"``, ``"ensemble_std"``.
        with_cost: Whether the dynamics model's own cost head is trained.
        use_delta_obs: Whether the model predicts ``next_obs - obs`` instead of ``next_obs``.
        reward_scale: Multiplier applied to dataset rewards when building targets.
        cost_scale: Multiplier applied to dataset costs when building targets.
        cost_coef: Weight of the cost-head loss inside the dynamics loss.
    """

    def __init__(
        self,
        model: nn.Module,
        cost_model: nn.Module,
        optim: torch.optim.Optimizer,
        cost_optim: torch.optim.Optimizer,
        scaler: "StandardScaler",
        terminal_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
        use_scheduler: bool = False,
        dynamics_scheduler = None,
        cost_model_scheduler = None,
        penalty_coef: float = 0.0,
        uncertainty_mode: str = "ensemble_std",
        with_cost: bool = True,
        use_delta_obs: bool = True,
        reward_scale: float = 0.1,
        cost_scale: float = 1.0,
        cost_coef: float = 1.0,
    ) -> None:
        super().__init__()
        self.model = model
        self.cost_model = cost_model
        self.optim = optim
        self.cost_optim = cost_optim
        self.use_scheduler = use_scheduler
        self.dynamics_scheduler = dynamics_scheduler
        self.cost_model_scheduler = cost_model_scheduler

        self.scaler = scaler
        self.terminal_fn = terminal_fn
        self._penalty_coef = penalty_coef
        self._uncertainty_mode = uncertainty_mode
        self._with_cost = with_cost
        self.reward_scale = reward_scale
        self.cost_scale = cost_scale
        self.use_delta_obs = use_delta_obs
        self.cost_coef = cost_coef
        # A separate cost model is trained/used only when one was supplied.
        self.train_cost_model = cost_model is not None

    def _ensemble_forward(
        self,
        obs: np.ndarray,
        action: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the dynamics model; return ``(scaled_obs_act, mean, std)`` as NumPy arrays."""
        obs_act = np.concatenate([obs, action], axis=-1)
        obs_act = self.scaler.transform(obs_act)
        mean, logvar, _ = self.model(obs_act)
        mean = mean.cpu().numpy()
        logvar = logvar.cpu().numpy()
        if self.use_delta_obs:
            # NOTE(review): treats the last column as the reward, i.e. assumes
            # the model was built with with_reward=True — confirm.
            mean[..., :-1] += obs
        std = np.sqrt(np.exp(logvar))
        return obs_act, mean, std

    def _uncertainty_penalty(self, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
        """Return the per-sample uncertainty penalty as a float32 column.

        Axis 0 of ``mean``/``std`` is reduced as the ensemble axis and axis 2
        as the feature axis.

        Raises:
            ValueError: If ``uncertainty_mode`` is not a supported mode.
        """
        if self._uncertainty_mode == "aleatoric":
            # Largest predicted std norm across ensemble members.
            penalty = np.amax(np.linalg.norm(std, axis=2), axis=0)
        elif self._uncertainty_mode == "pairwise-diff":
            # Largest deviation of a member's mean from the ensemble mean.
            next_obses_mean = mean[..., :-1]
            next_obs_mean = np.mean(next_obses_mean, axis=0)
            diff = next_obses_mean - next_obs_mean
            penalty = np.amax(np.linalg.norm(diff, axis=2), axis=0)
        elif self._uncertainty_mode == "ensemble_std":
            # Root of the across-ensemble variance, averaged over features.
            next_obses_mean = mean[..., :-1]
            penalty = np.sqrt(next_obses_mean.var(0).mean(1))
        else:
            raise ValueError(f"unknown uncertainty mode: {self._uncertainty_mode!r}")
        return np.expand_dims(penalty, 1).astype(np.float32)

    @staticmethod
    def _record_improvements(new_losses: List[float], best_losses: List[float]) -> List[int]:
        """Update ``best_losses`` in place; return members that improved by more than 1%."""
        improved = []
        for i, (new_loss, old_loss) in enumerate(zip(new_losses, best_losses)):
            improvement = (old_loss - new_loss) / old_loss
            if improvement > 0.01:
                improved.append(i)
                best_losses[i] = new_loss
        return improved

    @torch.no_grad()
    def step(
        self,
        obs: np.ndarray,
        action: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict]:
        """Imagine a single forward step.

        Each batch element is sampled from one randomly chosen elite member.

        Returns:
            ``(next_obs, reward, terminal, info)``. ``info`` holds
            ``"raw_reward"``; ``"raw_cost"`` when a separate cost model is
            used (cost probability thresholded at 0.5); and ``"penalty"``
            (plus ``"penalty_cost"`` with a cost model) when the penalty
            coefficient is non-zero. The returned reward already has the
            penalty subtracted.
        """
        obs_act, mean, std = self._ensemble_forward(obs, action)
        if self.train_cost_model:
            cost_prob = self.cost_model(obs_act).cpu().numpy()

        ensemble_samples = (mean + np.random.normal(size=mean.shape) * std).astype(np.float32)

        # choose one model from ensemble
        _, batch_size, _ = ensemble_samples.shape
        batch_idxs = np.arange(batch_size)
        model_idxs = self.model.random_elite_idxs(batch_size)
        samples = ensemble_samples[model_idxs, batch_idxs]

        next_obs = samples[..., :-1]
        reward = samples[..., -1:]
        terminal = self.terminal_fn(obs, action, next_obs)
        info = {"raw_reward": reward}
        if self.train_cost_model:
            cost_model_idxs = self.cost_model.random_elite_idxs(batch_size)
            cost_prob_samples = cost_prob[cost_model_idxs, batch_idxs]
            # Binarize the predicted probability at 0.5.
            cost = np.where(
                cost_prob_samples < 0.5,
                np.zeros_like(cost_prob_samples),
                np.ones_like(cost_prob_samples)
            )
            info["raw_cost"] = cost

        if self._penalty_coef:
            penalty = self._uncertainty_penalty(mean, std)
            assert penalty.shape == reward.shape
            reward = reward - self._penalty_coef * penalty
            info["penalty"] = penalty * self._penalty_coef
            if self.train_cost_model:
                info["penalty_cost"] = cost + self._penalty_coef * penalty

        return next_obs, reward, terminal, info

    @torch.no_grad()
    def safe_step(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        cost_func,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict]:
        """Imagine a single forward step, with the cost taken from ``cost_func``.

        ``cost_func`` is called once per (member, batch element) on that
        member's mean next observation; the reported cost is the maximum over
        ensemble members. The learned cost model is not consulted here.

        Returns:
            ``(next_obs, reward, terminal, info)``. ``info`` holds
            ``"raw_reward"``, ``"cost"`` and, when the penalty coefficient is
            non-zero, ``"penalty"``.
        """
        _, mean, std = self._ensemble_forward(obs, action)

        ensemble_samples = (mean + np.random.normal(size=mean.shape) * std).astype(np.float32)
        num_models, batch_size, _ = ensemble_samples.shape

        # Evaluate the cost on every member's mean prediction, then take the max.
        ensemble_next_obs = mean[..., :-1].reshape(num_models * batch_size, -1)
        ensemble_costs = np.array(
            [cost_func(member_next_obs) for member_next_obs in ensemble_next_obs]
        ).reshape(num_models, batch_size, 1)
        next_costs = np.max(ensemble_costs, axis=0)

        # choose one model from ensemble
        model_idxs = self.model.random_elite_idxs(batch_size)
        samples = ensemble_samples[model_idxs, np.arange(batch_size)]

        next_obs = samples[..., :-1]
        reward = samples[..., -1:]

        assert next_costs.shape == reward.shape, next_costs.shape

        terminal = self.terminal_fn(obs, action, next_obs)
        info = {"raw_reward": reward, "cost": next_costs}

        if self._penalty_coef:
            penalty = self._uncertainty_penalty(mean, std)
            assert penalty.shape == reward.shape
            reward = reward - self._penalty_coef * penalty
            info["penalty"] = penalty * self._penalty_coef

        return next_obs, reward, terminal, info

    @torch.no_grad()
    def sample_next_obss(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        num_samples: int
    ) -> torch.Tensor:
        """Draw ``num_samples`` next-observation samples from every elite member.

        The result stacks the draws along a new leading axis, ahead of the
        elite-member axis.
        """
        obs_act = torch.cat([obs, action], dim=-1)
        obs_act = self.scaler.transform_tensor(obs_act)
        mean, logvar, _ = self.model(obs_act)
        if self.use_delta_obs:
            mean[..., :-1] += obs
        std = torch.sqrt(torch.exp(logvar))

        # Keep only the elite members.
        elite_idxs = self.model.elites.data.cpu().numpy()
        mean = mean[elite_idxs]
        std = std[elite_idxs]

        samples = torch.stack([mean + torch.randn_like(std) * std for _ in range(num_samples)], 0)
        next_obss = samples[..., :-1]
        return next_obss

    def format_samples_for_training(self, data: Dict) -> Tuple[np.ndarray, np.ndarray]:
        """Build ``(inputs, targets)`` arrays from a transition dict.

        Inputs are ``[obs, action]``. Targets are ``[delta_obs or next_obs,
        scaled reward]``, with a scaled cost appended as the last column when
        either the dynamics cost head or a separate cost model is trained.
        ``data["costs"]`` is only read in that case.
        """
        obss = data["observations"]
        actions = data["actions"]
        next_obss = data["next_observations"]
        rewards = (data["rewards"] * self.reward_scale).reshape(-1,1)
        delta_obss = next_obss - obss if self.use_delta_obs else next_obss
        inputs = np.concatenate((obss, actions), axis=-1)
        if self._with_cost or self.train_cost_model:
            costs = (data["costs"] * self.cost_scale).reshape(-1,1)
            targets = np.concatenate((delta_obss, rewards, costs), axis=-1)
        else:
            targets = np.concatenate((delta_obss, rewards), axis=-1)
        return inputs, targets

    def train(
        self,
        data: Dict,
        logger,
        max_epochs: Optional[float] = None,
        max_epochs_since_update: int = 5,
        batch_size: int = 256,
        holdout_ratio: float = 0.2,
        logvar_loss_coef: float = 0.01
    ) -> None:
        """Train the dynamics model, then the separate cost model if one was supplied.

        Each phase stops once no member improved its holdout loss by more
        than 1% for ``max_epochs_since_update`` consecutive epochs, or once
        ``max_epochs`` is reached. After each phase the elites are selected,
        the saved weights are loaded back and the model is written to
        ``logger.model_dir``.
        """
        inputs, targets = self.format_samples_for_training(data)
        data_size = inputs.shape[0]
        # Hold out a fraction of the data, capped at 1000 samples.
        holdout_size = min(int(data_size * holdout_ratio), 1000)
        train_size = data_size - holdout_size
        train_splits, holdout_splits = torch.utils.data.random_split(range(data_size), (train_size, holdout_size))
        train_inputs, train_targets = inputs[train_splits.indices], targets[train_splits.indices]
        holdout_inputs, holdout_targets = inputs[holdout_splits.indices], targets[holdout_splits.indices]

        self.scaler.fit(train_inputs)
        train_inputs = self.scaler.transform(train_inputs)
        holdout_inputs = self.scaler.transform(holdout_inputs)
        holdout_losses = [1e10 for _ in range(self.model.num_ensemble)]

        # One bootstrap resample (with replacement) of the training set per member.
        data_idxes = np.random.randint(train_size, size=[self.model.num_ensemble, train_size])
        def shuffle_rows(arr):
            idxes = np.argsort(np.random.uniform(size=arr.shape), axis=-1)
            return arr[np.arange(arr.shape[0])[:, None], idxes]

        ### train_dynamics_model
        epoch = 0
        cnt = 0
        logger.log("Training dynamics:")
        while True:
            epoch += 1
            train_loss, train_cost_loss = self.learn(train_inputs[data_idxes], train_targets[data_idxes], batch_size, logvar_loss_coef)
            new_holdout_losses, new_holdout_cost_losses = self.validate(holdout_inputs, holdout_targets)
            holdout_loss = (np.sort(new_holdout_losses)[:self.model.num_elites]).mean()
            if self._with_cost:
                holdout_cost_loss = (np.sort(new_holdout_cost_losses)[:self.model.num_elites]).mean()
            logger.logkv("loss/dynamics_train_loss", train_loss)
            logger.logkv("loss/dynamics_holdout_loss", holdout_loss)
            # The scheduler is optional (defaults to None); only log the LR when present.
            if self.dynamics_scheduler is not None:
                logger.logkv("update/dynamics_lr", self.dynamics_scheduler.get_last_lr()[0])
            if self._with_cost:
                logger.logkv("loss/cost_train_loss", train_cost_loss)
                logger.logkv("loss/cost_holdout_loss", holdout_cost_loss)
                logger.logkv("loss/obsrew_train_loss", train_loss-self.cost_coef * train_cost_loss)
                logger.logkv("loss/obsrew_holdout_loss", holdout_loss-self.cost_coef * holdout_cost_loss)
            else:
                logger.logkv("loss/obsrew_train_loss", train_loss)
                logger.logkv("loss/obsrew_holdout_loss", holdout_loss)
            logger.set_timestep(epoch)
            logger.dumpkvs(exclude=["policy_training_progress"])

            # shuffle data for each base learner
            data_idxes = shuffle_rows(data_idxes)

            indexes = self._record_improvements(new_holdout_losses, holdout_losses)
            if len(indexes) > 0:
                self.model.update_save(indexes)
                cnt = 0
            else:
                cnt += 1

            if (cnt >= max_epochs_since_update) or (max_epochs and (epoch >= max_epochs)):
                break

        indexes = self.select_elites(holdout_losses)
        self.model.set_elites(indexes)
        self.model.load_save()
        self.save(logger.model_dir)
        self.model.eval()
        logger.log("dynamics elites:{} , holdout loss: {}".format(indexes, (np.sort(holdout_losses)[:self.model.num_elites]).mean()))

        ### train_cost_model
        if self.train_cost_model:
            holdout_cost_losses = [1e10 for _ in range(self.cost_model.num_ensemble)]
            epoch = 0
            cnt = 0
            logger.log("Training cost model:")
            while True:
                epoch += 1
                train_cost_model_loss = self.learn_cost(train_inputs[data_idxes], train_targets[data_idxes], batch_size, logvar_loss_coef)
                new_holdout_cost_model_losses = self.validate_cost(holdout_inputs, holdout_targets)
                holdout_loss = (np.sort(new_holdout_cost_model_losses)[:self.cost_model.num_elites]).mean()
                logger.logkv("loss/cost_model_train_loss", train_cost_model_loss)
                logger.logkv("loss/cost_model_holdout_loss", holdout_loss)
                # The scheduler is optional (defaults to None); only log the LR when present.
                if self.cost_model_scheduler is not None:
                    logger.logkv("update/cost_model_lr", self.cost_model_scheduler.get_last_lr()[0])
                logger.set_timestep(epoch)
                logger.dumpkvs(exclude=["policy_training_progress"])

                # shuffle data for each base learner
                data_idxes = shuffle_rows(data_idxes)

                indexes = self._record_improvements(new_holdout_cost_model_losses, holdout_cost_losses)
                if len(indexes) > 0:
                    self.cost_model.update_save(indexes)
                    cnt = 0
                else:
                    cnt += 1

                if (cnt >= max_epochs_since_update) or (max_epochs and (epoch >= max_epochs)):
                    break

            # Rank with the cost model's own elite count, not the dynamics model's.
            indexes = self.select_elites(holdout_cost_losses, self.cost_model.num_elites)
            self.cost_model.set_elites(indexes)
            self.cost_model.load_save()
            self.save_cost(logger.model_dir)
            self.cost_model.eval()
            logger.log("cost model elites:{} , holdout loss: {}".format(indexes, (np.sort(holdout_cost_losses)[:self.cost_model.num_elites]).mean()))

    def learn(
        self,
        inputs: np.ndarray,
        targets: np.ndarray,
        batch_size: int = 256,
        logvar_loss_coef: float = 0.01
    ) -> Tuple[float, float]:
        """Run one epoch of dynamics-model updates.

        ``inputs``/``targets`` are batched along axis 1; axis 0 is the
        ensemble axis as built by :meth:`train`.

        Returns:
            ``(mean total loss, mean cost-head loss)`` over the epoch's
            batches; the second value is 0 when the cost head is not trained.
        """
        self.model.train()
        train_size = inputs.shape[1]
        losses = []
        cost_losses = []
        bce_loss = nn.BCELoss(reduction='none')
        for batch_num in tqdm(range(int(np.ceil(train_size / batch_size)))):
            batch_slice = slice(batch_num * batch_size, (batch_num + 1) * batch_size)
            inputs_batch = inputs[:, batch_slice]
            targets_batch = torch.as_tensor(targets[:, batch_slice]).to(self.model.device)
            if self._with_cost or self.train_cost_model:
                # The last target column is the cost; strip it from the regression target.
                targets_cost_batch = targets_batch[...,-1:]
                targets_batch = targets_batch[...,:-1]
            mean, logvar, cost_prob = self.model(inputs_batch)
            inv_var = torch.exp(-logvar)
            # Average over batch and dim, sum over ensembles.
            mse_loss_inv = (torch.pow(mean - targets_batch, 2) * inv_var).mean(dim=(1, 2))
            var_loss = logvar.mean(dim=(1, 2))
            loss = mse_loss_inv.sum() + var_loss.sum()
            loss = loss + self.model.get_decay_loss()
            # Pull the learned log-variance bounds towards each other.
            loss = loss + logvar_loss_coef * self.model.max_logvar.sum() - logvar_loss_coef * self.model.min_logvar.sum()
            if self._with_cost:
                cost_loss = bce_loss(cost_prob, targets_cost_batch).mean(dim=(1,2))
                loss += self.cost_coef * cost_loss.sum()

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            if self.use_scheduler:
                self.dynamics_scheduler.step()

            losses.append(loss.item())
            if self._with_cost:
                cost_losses.append(cost_loss.sum().item())
            else:
                cost_losses.append(0)
        return np.mean(losses), np.mean(cost_losses)

    def learn_cost(
        self,
        inputs: np.ndarray,
        targets: np.ndarray,
        batch_size: int = 256,
        logvar_loss_coef: float = 0.01
    ) -> float:
        """Run one epoch of cost-model updates; return the mean batch loss.

        Only the last target column (the cost) is used. ``logvar_loss_coef``
        is accepted for signature parity with :meth:`learn` and is unused.
        """
        self.cost_model.train()
        train_size = inputs.shape[1]
        cost_losses = []
        bce_loss = nn.BCELoss(reduction='none')
        for batch_num in tqdm(range(int(np.ceil(train_size / batch_size)))):
            batch_slice = slice(batch_num * batch_size, (batch_num + 1) * batch_size)
            inputs_batch = inputs[:, batch_slice]
            targets_batch = torch.as_tensor(targets[:, batch_slice]).to(self.cost_model.device)
            targets_cost_batch = targets_batch[...,-1:]
            cost_prob = self.cost_model(inputs_batch)
            # Average over batch and dim, sum over ensembles.
            cost_loss = bce_loss(cost_prob, targets_cost_batch).mean(dim=(1,2))
            loss = cost_loss.sum()

            self.cost_optim.zero_grad()
            loss.backward()
            self.cost_optim.step()
            if self.use_scheduler:
                self.cost_model_scheduler.step()

            cost_losses.append(loss.item())
        return np.mean(cost_losses)

    @torch.no_grad()
    def validate(
        self, inputs: np.ndarray, targets: np.ndarray
    ) -> Tuple[List[float], Optional[List[float]]]:
        """Evaluate the dynamics model on holdout data.

        Returns:
            ``(val_loss, val_cost_loss)`` with one entry per ensemble member.
            ``val_loss`` is the mean squared error, plus the weighted cost
            loss when the cost head is trained; ``val_cost_loss`` is ``None``
            when it is not.
        """
        self.model.eval()
        bce_loss = nn.BCELoss(reduction='none')
        mean, _, cost_prob = self.model(inputs)
        targets = torch.as_tensor(targets).to(self.model.device)
        # Give every ensemble member its own copy of the targets.
        targets = targets.unsqueeze(0).repeat(mean.shape[0],1,1)
        if self._with_cost or self.train_cost_model:
            targets_cost = targets[...,-1:]
            targets = targets[...,:-1]
        loss = ((mean - targets) ** 2).mean(dim=(1, 2))
        if self._with_cost:
            cost_loss = bce_loss(cost_prob, targets_cost).mean(dim=(1,2))
            loss += self.cost_coef * cost_loss
        val_loss = list(loss.cpu().numpy())
        if self._with_cost:
            val_cost_loss = list(cost_loss.cpu().numpy())
        else:
            val_cost_loss = None
        return val_loss, val_cost_loss

    @torch.no_grad()
    def validate_cost(self, inputs: np.ndarray, targets: np.ndarray) -> List[float]:
        """Evaluate the cost model on holdout data; return one BCE loss per ensemble member."""
        self.cost_model.eval()
        bce_loss = nn.BCELoss(reduction='none')
        cost_prob = self.cost_model(inputs)
        targets = torch.as_tensor(targets).to(self.cost_model.device)
        # Give every ensemble member its own copy of the targets.
        targets = targets.unsqueeze(0).repeat(cost_prob.shape[0],1,1)
        targets_cost = targets[...,-1:]
        cost_loss = bce_loss(cost_prob, targets_cost).mean(dim=(1,2))
        val_cost_loss = list(cost_loss.cpu().numpy())
        return val_cost_loss

    def select_elites(self, metrics: List, num_elites: Optional[int] = None) -> List[int]:
        """Return the indices of the ``num_elites`` smallest metrics, best first.

        Args:
            metrics: One loss value per ensemble member.
            num_elites: How many indices to return. Defaults to
                ``self.model.num_elites``.

        Raises:
            IndexError: If more elites are requested than metrics are given.
        """
        if num_elites is None:
            num_elites = self.model.num_elites
        if num_elites > len(metrics):
            raise IndexError(
                f"requested {num_elites} elites but only {len(metrics)} metrics were given"
            )
        # sorted() is stable, so ties keep their original member order.
        ranked = sorted(range(len(metrics)), key=lambda index: metrics[index])
        return ranked[:num_elites]

    def save(self, save_path: str) -> None:
        """Write the dynamics model's state dict and the scaler into ``save_path``."""
        torch.save(self.model.state_dict(), os.path.join(save_path, "dynamics.pth"))
        self.scaler.save_scaler(save_path)

    def save_cost(self, save_path: str) -> None:
        """Write the cost model's state dict into ``save_path``."""
        torch.save(self.cost_model.state_dict(), os.path.join(save_path, "cost_model.pth"))

    def load(self, load_path: str) -> None:
        """Load the dynamics model's state dict and the scaler from ``load_path``."""
        self.model.load_state_dict(torch.load(os.path.join(load_path, "dynamics.pth"), map_location=self.model.device))
        self.scaler.load_scaler(load_path)

    def load_cost(self, load_path: str) -> None:
        """Load the cost model's state dict from ``load_path``."""
        self.cost_model.load_state_dict(torch.load(os.path.join(load_path, "cost_model.pth"), map_location=self.cost_model.device))
